{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finding Mislabelled Samples through ResNet MNIST Training Process"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook trains a ResNet model using MNIST dataset and employed TrainIng Data analYzer (TIDY) method based on Forgetting Events algorithm, specifically `ForgettingEventsInterpreter`, to investigate the training process by recording the predictions in the process. Some samples are manually mislabelled and we are able to find them by looking into the predictions along the training. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import interpretdl as it"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a ResNet architecture for MNIST, the code is borrowed from [PaddlePaddle Official Documentation](https://www.paddlepaddle.org.cn/tutorials/projectdetail/1516124)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle.nn as nn\n",
    "import paddle.nn.functional as F\n",
    "\n",
    "class ConvBNLayer(paddle.nn.Layer):\n",
    "    def __init__(self,\n",
    "                 num_channels,\n",
    "                 num_filters,\n",
    "                 filter_size,\n",
    "                 stride=1,\n",
    "                 groups=1,\n",
    "                 act=None):\n",
    "        super(ConvBNLayer, self).__init__()\n",
    "\n",
    "        self._conv = nn.Conv2D(\n",
    "            in_channels=num_channels,\n",
    "            out_channels=num_filters,\n",
    "            kernel_size=filter_size,\n",
    "            stride=stride,\n",
    "            padding=(filter_size - 1) // 2,\n",
    "            groups=groups,\n",
    "            bias_attr=False)\n",
    "\n",
    "        self._batch_norm = paddle.nn.BatchNorm2D(num_filters)\n",
    "        \n",
    "        self.act = act\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        y = self._conv(inputs)\n",
    "        y = self._batch_norm(y)\n",
    "        if self.act == 'leaky':\n",
    "            y = F.leaky_relu(x=out, negative_slope=0.1)\n",
    "        elif self.act == 'relu':\n",
    "            y = F.relu(x=y)\n",
    "        return y\n",
    "\n",
    "class BottleneckBlock(paddle.nn.Layer):\n",
    "    def __init__(self,\n",
    "                 num_channels,\n",
    "                 num_filters,\n",
    "                 stride,\n",
    "                 shortcut=True):\n",
    "        super(BottleneckBlock, self).__init__()\n",
    "        self.conv0 = ConvBNLayer(\n",
    "            num_channels=num_channels,\n",
    "            num_filters=num_filters,\n",
    "            filter_size=1,\n",
    "            act='relu')\n",
    "        self.conv1 = ConvBNLayer(\n",
    "            num_channels=num_filters,\n",
    "            num_filters=num_filters,\n",
    "            filter_size=3,\n",
    "            stride=stride,\n",
    "            act='relu')\n",
    "        self.conv2 = ConvBNLayer(\n",
    "            num_channels=num_filters,\n",
    "            num_filters=num_filters * 4,\n",
    "            filter_size=1,\n",
    "            act=None)\n",
    "        if not shortcut:\n",
    "            self.short = ConvBNLayer(\n",
    "                num_channels=num_channels,\n",
    "                num_filters=num_filters * 4,\n",
    "                filter_size=1,\n",
    "                stride=stride)\n",
    "\n",
    "        self.shortcut = shortcut\n",
    "\n",
    "        self._num_channels_out = num_filters * 4\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        y = self.conv0(inputs)\n",
    "        conv1 = self.conv1(y)\n",
    "        conv2 = self.conv2(conv1)\n",
    "\n",
    "        if self.shortcut:\n",
    "            short = inputs\n",
    "        else:\n",
    "            short = self.short(inputs)\n",
    "\n",
    "        y = paddle.add(x=short, y=conv2)\n",
    "        y = F.relu(y)\n",
    "        return y\n",
    "\n",
    "class ResNet(paddle.nn.Layer):\n",
    "    def __init__(self, layers=50, class_dim=1):\n",
    "        super(ResNet, self).__init__()\n",
    "        self.layers = layers\n",
    "        supported_layers = [50, 101, 152]\n",
    "        assert layers in supported_layers, \\\n",
    "            \"supported layers are {} but input layer is {}\".format(supported_layers, layers)\n",
    "\n",
    "        if layers == 50:\n",
    "            depth = [3, 4, 6, 3]\n",
    "        elif layers == 101:\n",
    "            depth = [3, 4, 23, 3]\n",
    "        elif layers == 152:\n",
    "            depth = [3, 8, 36, 3]\n",
    "        \n",
    "        num_filters = [64, 128, 256, 512]\n",
    "\n",
    "        self.conv = ConvBNLayer(\n",
    "            num_channels=1,\n",
    "            num_filters=64,\n",
    "            filter_size=7,\n",
    "            stride=2,\n",
    "            act='relu')\n",
    "        self.pool2d_max = nn.MaxPool2D(\n",
    "            kernel_size=3,\n",
    "            stride=2,\n",
    "            padding=1)\n",
    "\n",
    "        self.bottleneck_block_list = []\n",
    "        num_channels = 64\n",
    "        for block in range(len(depth)):\n",
    "            shortcut = False\n",
    "            for i in range(depth[block]):\n",
    "                bottleneck_block = self.add_sublayer(\n",
    "                    'bb_%d_%d' % (block, i),\n",
    "                    BottleneckBlock(\n",
    "                        num_channels=num_channels,\n",
    "                        num_filters=num_filters[block],\n",
    "                        stride=2 if i == 0 and block != 0 else 1, # c3、c4、c5将会在第一个残差块使用stride=2；其余所有残差块stride=1\n",
    "                        shortcut=shortcut))\n",
    "                num_channels = bottleneck_block._num_channels_out\n",
    "                self.bottleneck_block_list.append(bottleneck_block)\n",
    "                shortcut = True\n",
    "\n",
    "        self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D(output_size=1)\n",
    "\n",
    "        import math\n",
    "        stdv = 1.0 / math.sqrt(2048 * 1.0)\n",
    "        \n",
    "        self.out = nn.Linear(in_features=2048, out_features=class_dim,\n",
    "                      weight_attr=paddle.ParamAttr(\n",
    "                          initializer=paddle.nn.initializer.Uniform(-stdv, stdv)))\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        y = self.conv(inputs)\n",
    "        y = self.pool2d_max(y)\n",
    "        for bottleneck_block in self.bottleneck_block_list:\n",
    "            y = bottleneck_block(y)\n",
    "        y = self.pool2d_avg(y)\n",
    "        y = paddle.reshape(y, [y.shape[0], -1])\n",
    "        y = self.out(y)\n",
    "        return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the MNIST dataset generator from **paddle.vision** to get the labels and manually mislabel 1% samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from paddle.vision.transforms import ToTensor, Resize, Compose\n",
    "from paddle.vision.datasets import MNIST\n",
    "\n",
    "train_dataset = MNIST(mode='train', transform=Compose([Resize(size=32), ToTensor()]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare manually mislabelled samples\n",
    "labels = []\n",
    "for i in range(0, 60000, 100):\n",
    "    labels.append(np.random.choice(np.delete(np.arange(10), train_dataset[i][-1])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ResNet(class_dim=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a new data generator based on MNIST data generator. It replaces 1% true labels by the wrong ones. \n",
    "\n",
    "**Important:** the data generator shoud generate the index of each sample as the first element so that each sample's behavior can be recorded according to its index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reader_prepare(dataset, new_labels):\n",
    "    def reader():\n",
    "        idx = 0\n",
    "        for data, label in dataset:\n",
    "            if idx % 100 == 0:\n",
    "                label = new_labels[idx // 100]\n",
    "            yield idx, data, int(label)\n",
    "            idx += 1\n",
    "    return reader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up a data loader with batch size of 128, and an Momentum optimizer for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128\n",
    "train_reader = paddle.batch(\n",
    "    reader_prepare(train_dataset, labels), batch_size=BATCH_SIZE)\n",
    "optimizer = paddle.optimizer.Momentum(learning_rate=0.001,\n",
    "                     momentum=0.9,\n",
    "                     parameters=model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First initialize the `ForgettingEventsInterpreter` and then start `interpret`ing the training process by training 100 epochs. \n",
    "\n",
    "*stats* is a dictionary that maps image index to predictions in the training process and if they are correct; *noisy_samples* is a list of mislabelled image ids. *stats* is saved at \"assets/stats.pkl\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training 100 epochs. This may take some time.\n",
      "| Epoch [  1/100] Iter[  2]\t\tLoss: 2.5311 Acc@1: 10.938%"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:636: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Epoch [100/100] Iter[469]\t\tLoss: 0.0000 Acc@1: 100.000%"
     ]
    }
   ],
   "source": [
    "fe = it.ForgettingEventsInterpreter(model, True)\n",
    "\n",
    "epochs = 100\n",
    "print('Training %d epochs. This may take some time.' % epochs)\n",
    "stats, noisy_samples = fe.interpret(\n",
    "    train_reader,\n",
    "    optimizer,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=epochs,\n",
    "    noisy_labels=True,\n",
    "    save_path='assets')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate the recall, precision and F1 for our found noisy samples. \n",
    "\n",
    "88.7% of mislabelled samples have been found and among those samples found, 80.1% are indeed mislabelled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall:  0.8866666666666667\n",
      "Precision:  0.8012048192771084\n",
      "F1 Score:  0.8417721518987342\n"
     ]
    }
   ],
   "source": [
    "recall = np.sum([id_ % 100 == 0 for id_ in noisy_samples]) / (60000 / 100)\n",
    "precision = np.sum([id_ % 100 == 0 for id_ in noisy_samples]) / len(noisy_samples)\n",
    "print('Recall: ', recall)\n",
    "print('Precision: ', precision)\n",
    "print('F1 Score: ', 2 * (recall * precision) / (recall + precision))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "paddle2.0",
   "language": "python",
   "name": "paddle2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
